import filecmp
import os
import subprocess

import table_ast
import table_lexer
import table_parser

class GoMessageHandler:
    """Settings describing one generated Go message-handler file.

    Every constructor argument is stored unchanged as an attribute of the
    same name; the generator reads them back when writing the file.
    """

    def __init__(self, forModuleName, forPkgName, msgdir, outdir,
                 objectType, handlerFile, msgPrefixs, msgTags, msgMax):
        # Go module path and package name of the generated handler file.
        self.forModuleName = forModuleName
        self.forPkgName = forPkgName
        # Sub-path (under forModuleName) of the protocol package to import,
        # and directory where the handler file is written.
        self.msgdir = msgdir
        self.outdir = outdir
        # Receiver type of the handler methods and output base file name.
        self.objectType = objectType
        self.handlerFile = handlerFile
        # Message selection: name prefixes and tag names.
        self.msgPrefixs = msgPrefixs
        self.msgTags = msgTags
        # Name of the constant that sizes the handler arrays.
        self.msgMax = msgMax

def to_go_files(indir, outdir, msgHandlers):
    """Generate Go sources from the table files in ``indir``.

    Every entry of ``indir`` is parsed and written to ``outdir/<name>.go``
    (as ``package protocol``).  Then one handler file is generated for each
    entry of ``msgHandlers``.  Each output is first written to a
    ``<name>.go~`` temporary and formatted with gofmt.  It is moved over the
    final ``.go`` path only when its content differs, so an identical
    existing file is left untouched.

    Entries of ``indir`` are processed in sorted order.  This makes the
    order of the generated handler tables independent of the filesystem.

    Does nothing when ``indir`` does not exist.  Output directories,
    including missing parents, are created as needed.
    """

    def check_replace_file(filename):
        # filename is the '.go~' temporary; filename[:-1] is the final path.
        target = filename[:-1]
        # Run gofmt without a shell so paths containing spaces or shell
        # metacharacters are passed through intact.  As with the former
        # os.system call, a gofmt failure is reported but not fatal.
        try:
            status = subprocess.run(['gofmt', '-w', filename]).returncode
        except OSError as err:
            print('gofmt could not be run on %s: %s' % (filename, err))
        else:
            if status != 0:
                print('gofmt exited with status %d on %s' % (status, filename))
        if not os.path.exists(target) or \
           not filecmp.cmp(filename, target, shallow=False):
            os.replace(filename, target)
        else:
            os.remove(filename)

    def parse_ast(infile):
        # Lex then parse one table file; returns the parser's AST root.
        lexer = table_lexer.TableLexer()
        lexer.parse(infile)
        parser = table_parser.TableParser()
        return parser.parse(lexer)

    def parse_blanks(infile):
        # 1-based numbers of whitespace-only lines, in ascending order.
        with open(infile, 'r', encoding='utf8') as fo:
            return [n for n, line in enumerate(fo, 1) if line.isspace()]

    def generate_go_file(indir, outdir, filename):
        # Translate one table file; returns (base name, AST root).
        filepart = os.path.splitext(filename)[0]
        outpathpart = os.path.join(outdir, filepart)
        infile = os.path.abspath(os.path.join(indir, filename))
        print('go files %s ...' % infile)
        root, blanks = parse_ast(infile), parse_blanks(infile)
        outfile, prefix = outpathpart + '.go~', 'package protocol\n\n'
        GeneratorGo(blanks).Generate(root, outfile, prefix)
        check_replace_file(outfile)
        return filepart, root

    def generate_message_handler(rootFileList, msgHandler):
        # Write the dispatch-table file described by one GoMessageHandler.
        print('go message handler %s ...' % msgHandler.handlerFile)
        os.makedirs(msgHandler.outdir, exist_ok=True)
        outfile = os.path.join(msgHandler.outdir, msgHandler.handlerFile + '.go~')
        prefix = 'package %s\n\nimport (\n\t"%s/%s"\n\t"%s/fusion"\n)\n' % \
            (msgHandler.forPkgName, msgHandler.forModuleName,
             msgHandler.msgdir, msgHandler.forModuleName)
        GeneratorMHGo().Generate(outfile, rootFileList, msgHandler, prefix)
        check_replace_file(outfile)

    if not os.path.exists(indir):
        return
    os.makedirs(outdir, exist_ok=True)
    rootFileList = {}
    # Sorted: rootFileList's insertion order decides the order of entries in
    # the generated handler tables, and os.listdir order is arbitrary.
    for filename in sorted(os.listdir(indir)):
        name, root = generate_go_file(indir, outdir, filename)
        rootFileList[name] = root
    for msgHandler in msgHandlers:
        generate_message_handler(rootFileList, msgHandler)

class GeneratorBase:
    """Shared helpers for the Go code generators."""

    @staticmethod
    def _to_go_name(name):
        """Return ``name`` with its first character upper-cased.

        Go exports identifiers that start with an upper-case letter.  An
        empty string is returned unchanged; previously ``name[0]`` raised
        IndexError.
        """
        return name[:1].upper() + name[1:]

class GeneratorGo(GeneratorBase):
    """Translate one parsed table file into Go ``const`` blocks.

    Blank source lines are replayed into the output, so the generated Go
    keeps the vertical layout of the input.  Enum members and comments that
    share a source line are emitted on one Go line.  Joining lines works by
    seeking back in the output file and overwriting, so ``self.fo`` must be
    seekable.
    """

    def __init__(self, blanks):
        # blanks: 1-based line numbers of whitespace-only source lines,
        # consumed front to back (assumed ascending, as parse_blanks builds
        # them).  self.i indexes the next blank line not yet written.
        self.blanks, self.i = blanks, 0
        # Source line number of the most recently written line; 0 = none.
        self.lineno = 0

    def Generate(self, root, outfile, prefix = ''):
        """Write ``prefix`` then every top-level declaration of ``root``.

        Args:
            root: AST root whose ``externalDeclarations`` are translated.
            outfile: path of the file to create (overwritten).
            prefix: text written first, e.g. the ``package`` clause.
        """
        with open(outfile, 'w', encoding='utf8') as self.fo:
            self.fo.write(prefix)
            for entity in root.externalDeclarations:
                self.__generateentity(entity, 0)

    def __generateentity(self, entity, level):
        # Dispatch on AST node type; any other node type is skipped.
        if isinstance(entity, table_ast.EnumDefinition):
            return self.__generateenumblock(entity, level)

        if isinstance(entity, table_ast.EnumDeclaration):
            return self.__generateenummember(entity, level)

        if isinstance(entity, table_ast.Comment):
            # A comment starting on the last written source line is appended
            # to that output line (trailing comment).
            return self.__writeline2file(level,
                entity.text.value.rstrip(), entity.text.lineno,
                canInline = True, strSpacer = ' ')

    def __generateenumblock(self, entity, level):
        """Emit ``// Name`` followed by a ``const ( ... )`` block."""
        # Per-enum state:
        # - idList/valList: ids and values of the current output line group.
        # - nsVal/iVal: base name and offset for implicitly valued members.
        # - wpos: file offset where the current line group starts (None until
        #   the first member is written).
        self.idList, self.valList = [], []
        self.nsVal, self.iVal = '', -1
        self.wpos = None
        self.__writeline2file(level, '// %s\nconst' % entity.name.value,
            entity.name.lineno)
        self.__generateblockmember(entity, level)

    def __generateblockmember(self, entity, level):
        # '(' is joined onto the 'const' line; members are indented one level.
        self.__writeline2file(level, '(',
            entity.name.lineno, canInline = True, strSpacer = ' ')
        for member in entity.declarationList.declarationList:
            self.__generateentity(member, level + 1)
        self.__writeline2file(level, ')')

    def __generateenummember(self, member, level):
        self.__writeenummember(self._to_go_name(member.memName.value),
            self.__enumValue2formatStr(member.memValue), level,
            member.memName.lineno)

    def __writeenummember(self, idStr, valStr, level, lineno):
        """Write an enum member, merging it with members on the same line.

        Members sharing a source line are rendered as a single Go line,
        ``A, B = 0, 1``.  This is done by rewinding to the start of that line
        and rewriting it with the new member appended.  The rewritten text is
        never shorter, so no stale tail remains.
        """
        # Start a new line group when the source line changed.  Also start
        # one when no group exists yet in this enum (wpos is None), e.g. the
        # first member sits on the same source line as the enum name.
        # Without that check, seek(None) below raised TypeError.
        if self.wpos is None or self.lineno != lineno:
            self.wpos = self.fo.tell()
            self.idList.clear()
            self.valList.clear()
        self.fo.seek(self.wpos)
        self.idList.append(idStr)
        self.valList.append(valStr)
        self.__writeline2file(level,
            '%s = %s' % (', '.join(self.idList), ', '.join(self.valList)),
            lineno)

    def __writeline2file(self, level, linedata, lineno = -1, **kwargs):
        """Write ``linedata`` as one output line indented by ``level`` tabs.

        Blank source lines numbered below ``lineno`` are replayed first.
        When ``canInline`` is true and ``lineno`` equals the last written
        source line, the text is appended to the previous output line,
        separated by ``strSpacer``, instead of starting a new line.
        ``lineno=-1`` means "no source position".  In that case the recorded
        line number is not updated.
        """
        while self.i < len(self.blanks) and self.blanks[self.i] < lineno:
            self.fo.write('\n')
            self.i += 1
        if kwargs.get('canInline') and self.lineno == lineno:
            # Step back over the previous line terminator.  The file is opened
            # in text mode, so '\n' was written as os.linesep.
            # NOTE(review): arithmetic on text-mode tell() values is not
            # guaranteed by the io API; it relies on tell() being a plain byte
            # offset, which holds in practice for utf-8 output.
            self.fo.seek(self.fo.tell() - len(os.linesep))
            self.fo.write(kwargs.get('strSpacer', ' '))
        elif level != 0:
            self.fo.write('\t' * level)
        self.fo.write(linedata)
        self.fo.write('\n')
        if lineno != -1:
            self.lineno = lineno

    def __enumValue2formatStr(self, node):
        """Return the Go value expression for an enum member.

        - ``NsIdList`` value (presumably a possibly qualified reference to
          another constant; only its last id is used): that name.
          Following implicit members count up from it (``Name+1``, ...).
        - any other truthy value: converted with ``int()`` and used literally.
        - missing value: the previous value plus one.
        """
        if isinstance(node, table_ast.NsIdList):
            self.nsVal, self.iVal = self.__to_go_enum_value(node), 0
        elif node:
            self.nsVal, self.iVal = '', int(node)
        else:
            self.iVal += 1
        if self.nsVal and self.iVal != 0:
            return '%s+%d' % (self.nsVal, self.iVal)
        elif self.nsVal:
            return self.nsVal
        else:
            return str(self.iVal)

    @classmethod
    def __to_go_enum_value(cls, node):
        # Go name of the last identifier in the reference.
        return cls._to_go_name(node.idList[-1].value)

class GeneratorMHGo(GeneratorBase):
    """Generate a Go file holding message-handler dispatch tables.

    Two package-level arrays sized by ``protocol.<MsgMax>`` are written:
    ``<objectType>Handlers`` for plain messages and
    ``<objectType>RPCHandlers`` for messages tagged ``rpc``.  Each is keyed by
    ``protocol.<Message>`` constants and maps them to handler methods on
    ``*<objectType>``.
    """

    def Generate(self, outfile, rootFileList, msgHandler, prefix = ''):
        """Write the handler tables selected by ``msgHandler`` to ``outfile``.

        Args:
            outfile: path of the Go file to create (overwritten).
            rootFileList: mapping of table base name -> parsed AST root.
            msgHandler: object providing objectType, msgMax, msgPrefixs and
                msgTags (a GoMessageHandler).
            prefix: text written first (package clause and imports).
        """
        with open(outfile, 'w', encoding='utf8') as self.fo:
            msgs = self.__filter_message(
                rootFileList, msgHandler.msgPrefixs, msgHandler.msgTags)
            maxName = self._to_go_name(msgHandler.msgMax)
            self.fo.write(prefix)
            self.fo.write(('\nvar %sHandlers = ' +
                '[protocol.%s]func(*%s, *fusion.NetPacket) int{\n') %
                (msgHandler.objectType, maxName, msgHandler.objectType))
            self.__write_message_handlers(msgs, msgHandler, False)
            self.fo.write('}\n')
            self.fo.write(('\nvar %sRPCHandlers = [protocol.%s]' +
                'func(*%s, *fusion.NetPacket, *fusion.RPCReqMetaInfo) int{\n') %
                (msgHandler.objectType, maxName, msgHandler.objectType))
            self.__write_message_handlers(msgs, msgHandler, True)
            self.fo.write('}\n')

    def __write_message_handlers(self, msgs, msgHandler, isRPC = False):
        # Emit one table entry per message whose rpc flag matches isRPC.
        for msg in msgs:
            if bool(msg.rpc) != bool(isRPC):
                continue
            self.fo.write('\tprotocol.%s: (*%s).%s,\n' %
                (self._to_go_name(msg.parts[1]), msgHandler.objectType,
                 self.__beautify_message_name(msg.parts[1])))

    class MsgInfo:
        # One selected message: rpc flag plus (enum name, member name).
        def __init__(self, rpc, *parts):
            assert len(parts) == 2, 'msg part number must equal 2.'
            self.rpc, self.parts = rpc, parts

    @staticmethod
    def __beautify_message_name(msgName):
        """Build the handler method name, e.g. ``CS_LOGIN_REQ`` -> ``handleLoginReq``.

        The first ``_``-separated part (the prefix) is dropped when there is
        more than one part; each remaining part is title-cased.
        """
        parts = [part.title() for part in msgName.split('_')]
        kept = parts[1:] if len(parts) > 1 else parts
        return 'handle' + ''.join(kept)

    @classmethod
    def __filter_message(cls, rootFileList, msgPrefixs, msgTags):
        """Collect a MsgInfo for every enum member selected by prefix or tag."""
        msgs = []
        for root in rootFileList.values():
            for entity in root.externalDeclarations:
                if not isinstance(entity, table_ast.EnumDefinition):
                    continue
                members = entity.declarationList.declarationList
                for i, member in enumerate(members):
                    if not isinstance(member, table_ast.EnumDeclaration):
                        continue
                    tags = cls.__get_message_tags(
                        members[i+1:], member.memName.lineno)
                    if not cls.__check_message_available(
                            member.memName.value, tags, msgPrefixs, msgTags):
                        continue
                    isRPC = bool(tags) and 'rpc' in tags
                    msgs.append(cls.MsgInfo(isRPC,
                        entity.name.value, member.memName.value))
        return msgs

    @staticmethod
    def __get_message_tags(declarationList, lineno):
        """Return the tags of the trailing comment on source line ``lineno``.

        ``declarationList`` holds the declarations following the member.
        Members on the same or earlier lines are skipped.  The first other
        item must be a comment on ``lineno``, otherwise None is returned.
        A matching comment without a ``tags`` attribute also gives None.
        """
        for member in declarationList:
            if isinstance(member, table_ast.EnumDeclaration) and \
               member.memName.lineno <= lineno:
                continue
            if isinstance(member, table_ast.Comment) and \
               member.text.lineno == lineno:
                return getattr(member.text, 'tags', None)
            return None
        return None

    @staticmethod
    def __check_message_available(memName, tags, msgPrefixs, msgTags):
        """Decide whether a message gets a handler-table entry.

        When tags other than a lone ``rpc`` are present, the message is
        selected if any of ``msgTags`` is among them.  A message tagged
        ``ignore`` is never selected.  The misspelling ``ingore``, which this
        check used to test for, is still honoured.  Otherwise the message is
        selected when its name starts with ``<prefix>_`` for some prefix in
        ``msgPrefixs``.
        """
        if tags and (len(tags) > 1 or tags[0] != 'rpc'):
            if 'ignore' in tags or 'ingore' in tags:
                return False
            return any(msgTag in tags for msgTag in msgTags)
        return any(memName.startswith(msgPrefix + '_')
                   for msgPrefix in msgPrefixs)
